{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "83b1c45c",
   "metadata": {},
   "source": [
    "# Robustness Against Data Corruption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb0ee212",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Colab bootstrap: when the `google.colab` module is loaded, install the pinned\n",
    "# dependencies and clone the project sources. Outside Colab, `root` stays at the\n",
    "# current directory and nothing is installed.\n",
    "root = \".\"\n",
    "if \"google.colab\" in sys.modules:\n",
    "    print(\"Running in Colab.\")\n",
    "    # `%pip` (rather than `!pip3`) installs into the environment of the running kernel\n",
    "    %pip install matplotlib\n",
    "    %pip install einops==0.3.0\n",
    "    %pip install timm==0.4.9\n",
    "    !git clone https://github.com/xxxnell/how-do-vits-work.git\n",
    "    root = \"./how-do-vits-work\"\n",
    "    # put the checkout on the import path; skip if already present so that\n",
    "    # re-running this cell does not stack duplicate entries\n",
    "    if root not in sys.path:\n",
    "        sys.path.append(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bade143",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "import copy\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "# from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import models\n",
    "# import ops.trains as trains\n",
    "import ops.tests as tests\n",
    "import ops.datasets as datasets\n",
    "# import ops.schedulers as schedulers\n",
    "# import ops.adversarial as adv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6635ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration; the alternative configs are kept commented out.\n",
    "# config_path = \"%s/configs/cifar10_vit.yaml\" % root\n",
    "config_path = \"%s/configs/cifar100_vit.yaml\" % root\n",
    "# config_path = \"%s/configs/imagenet_vit.yaml\" % root\n",
    "\n",
    "# `yaml.load` without an explicit Loader is deprecated since PyYAML 5.1 and raises a\n",
    "# TypeError in PyYAML 6. `safe_load` parses plain config data and never constructs\n",
    "# arbitrary Python objects.\n",
    "with open(config_path) as f:\n",
    "    args = yaml.safe_load(f)\n",
    "    print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93447458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the loaded config into one dictionary per section. Every section is\n",
    "# deep-copied so that later edits to it leave `args` untouched; a section that is\n",
    "# absent from the config becomes None.\n",
    "dataset_args, train_args, val_args, model_args, optim_args, env_args = (\n",
    "    copy.deepcopy(args.get(section))\n",
    "    for section in (\"dataset\", \"train\", \"val\", \"model\", \"optim\", \"env\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0b8f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the datasets described by the `dataset` section of the config.\n",
    "dataset_train, dataset_test = datasets.get_dataset(download=True, **dataset_args)\n",
    "dataset_name = dataset_args[\"name\"]\n",
    "num_classes = len(dataset_train.classes)\n",
    "\n",
    "# From here on `dataset_train` / `dataset_test` refer to the DataLoaders; the\n",
    "# wrapped datasets stay reachable through their `.dataset` attribute.\n",
    "dataset_train = DataLoader(\n",
    "    dataset_train,\n",
    "    batch_size=train_args.get(\"batch_size\", 128),\n",
    "    num_workers=train_args.get(\"num_workers\", 4),\n",
    "    shuffle=True,\n",
    ")\n",
    "dataset_test = DataLoader(\n",
    "    dataset_test,\n",
    "    batch_size=val_args.get(\"batch_size\", 128),\n",
    "    num_workers=val_args.get(\"num_workers\", 4),\n",
    ")\n",
    "\n",
    "# report the split sizes and the number of classes\n",
    "print(\"Train: %s, Test: %s, Classes: %s\" % (len(dataset_train.dataset), len(dataset_test.dataset), num_classes))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad0723a5",
   "metadata": {},
   "source": [
    "## Load Pretrained Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a9da41c",
   "metadata": {},
   "source": [
    "Prepare the pretrained models (the cells below provide the snippets for ResNet-50 and ViT-Ti):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41dd592e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained ResNet-50 checkpoint for CIFAR-100.\n",
    "name = \"resnet_50\"\n",
    "uid = \"691cc9a9e4\"\n",
    "url = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar\"\n",
    "path = \"checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar\"\n",
    "\n",
    "# presumably stores the file behind `url` at `path` (project helper, not shown here)\n",
    "models.download(url=url, path=path)\n",
    "\n",
    "# timm does not provide a ResNet for CIFAR, so the project's own factory is used\n",
    "model = models.get_model(\n",
    "    name,\n",
    "    num_classes=num_classes,\n",
    "    stem=model_args.get(\"stem\", False),\n",
    ")\n",
    "\n",
    "# map the stored tensors to the GPU when one is available, otherwise to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    map_location = \"cuda\"\n",
    "else:\n",
    "    map_location = \"cpu\"\n",
    "checkpoint = torch.load(path, map_location=map_location)\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dece74f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import timm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Pretrained ViT-Ti checkpoint for CIFAR-100. This cell assigns `model` and `uid`,\n",
    "# so it replaces whatever an earlier model-loading cell stored under those names.\n",
    "uid = \"9857b21357\"\n",
    "url = \"https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar\"\n",
    "path = \"checkpoints/vit_ti_cifar100_9857b21357.pth.tar\"\n",
    "\n",
    "# presumably stores the file behind `url` at `path` (project helper, not shown here)\n",
    "models.download(url=url, path=path)\n",
    "\n",
    "model = timm.models.vision_transformer.VisionTransformer(\n",
    "    # for CIFAR\n",
    "    num_classes=num_classes,\n",
    "    img_size=32,\n",
    "    patch_size=2,\n",
    "    # for ViT-Ti\n",
    "    embed_dim=192,\n",
    "    depth=12,\n",
    "    num_heads=3,\n",
    "    qkv_bias=False,\n",
    ")\n",
    "# later cells read `model.name` to build output paths\n",
    "model.name = \"vit_ti\"\n",
    "models.stats(model)\n",
    "\n",
    "# map the stored tensors to the GPU when one is available, otherwise to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    map_location = \"cuda\"\n",
    "else:\n",
    "    map_location = \"cpu\"\n",
    "checkpoint = torch.load(path, map_location=map_location)\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dc83468",
   "metadata": {},
   "source": [
    "Parallelize the given `model` by splitting the input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ee6c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model once; the guard keeps a re-run of this cell from nesting\n",
    "# DataParallel wrappers. The `name` attribute is copied onto the wrapper because\n",
    "# later cells read `model.name`.\n",
    "if not isinstance(model, nn.DataParallel):\n",
    "    name = model.name\n",
    "    model = nn.DataParallel(model)\n",
    "    model.name = name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde2a379",
   "metadata": {},
   "source": [
    "Test model performance on in-domain data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7943b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the clean test split, on the GPU when one is available.\n",
    "gpu = torch.cuda.is_available()\n",
    "\n",
    "if gpu:\n",
    "    model = model.cuda()\n",
    "else:\n",
    "    model = model.cpu()\n",
    "\n",
    "# one row per `n_ff` value: [n_ff, <metrics returned by tests.test>]\n",
    "metrics_list = []\n",
    "for n_ff in [1]:\n",
    "    print(\"N: %s, \" % n_ff, end=\"\")\n",
    "    # the last value returned by tests.test is kept apart as `cal_diag`\n",
    "    *metrics, cal_diag = tests.test(\n",
    "        model, n_ff, dataset_test, verbose=False, gpu=gpu\n",
    "    )\n",
    "    metrics_list.append([n_ff] + metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fda55c4e",
   "metadata": {},
   "source": [
    "## Measure Predictive Performances on Corrupted Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c7585d8",
   "metadata": {},
   "source": [
    "We measure predictive performances on 75 corrupted datasets (= datasets corrupted by [15 different types](https://github.com/hendrycks/robustness/blob/master/assets/imagenet-c.png) with 5 levels of intensity each)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95bccfc1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "n_ff = 1\n",
    "\n",
    "gpu = torch.cuda.is_available()\n",
    "# fresh copies of the config sections used below\n",
    "val_args = copy.deepcopy(args.get(\"val\"))\n",
    "dataset_args = copy.deepcopy(args.get(\"dataset\"))\n",
    "\n",
    "if gpu:\n",
    "    model = model.cuda()\n",
    "else:\n",
    "    model = model.cpu()\n",
    "\n",
    "# metrics_c[intensity][ctype] -> metrics returned by tests.test (all but its last value)\n",
    "metrics_c = {}\n",
    "for intensity in range(1, 6):\n",
    "    metrics_c[intensity] = {}\n",
    "    for ctype in datasets.get_corruptions():\n",
    "        dataset_c = datasets.get_dataset_c(\n",
    "            **dataset_args, ctype=ctype, intensity=intensity, download=True\n",
    "        )\n",
    "        dataset_c = DataLoader(\n",
    "            dataset_c,\n",
    "            batch_size=val_args.get(\"batch_size\", 128),\n",
    "            num_workers=val_args.get(\"num_workers\", 4),\n",
    "        )\n",
    "        print(\"Corruption type: %s, Intensity: %d, \" % (ctype, intensity), end=\"\")\n",
    "        *metrics, cal_diag = tests.test(model, n_ff, dataset_c, verbose=False, gpu=gpu)\n",
    "        metrics_c[intensity][ctype] = metrics\n",
    "\n",
    "# write one CSV row per (intensity, corruption type) pair\n",
    "leaderboard_path = os.path.join(\"leaderboard\", \"logs\", dataset_name, model.name)\n",
    "Path(leaderboard_path).mkdir(parents=True, exist_ok=True)\n",
    "# note: despite its name, `metrics_dir` is the path of the CSV file itself\n",
    "metrics_dir = os.path.join(\n",
    "    leaderboard_path,\n",
    "    \"%s_%s_%s_%s_corrupted.csv\" % (dataset_name, model.name, uid, n_ff),\n",
    ")\n",
    "metrics_c_list = [\n",
    "    [i, typ, *metrics]\n",
    "    for i, typ_metrics in metrics_c.items()\n",
    "    for typ, metrics in typ_metrics.items()\n",
    "]\n",
    "tests.save_metrics(metrics_dir, metrics_c_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013845ee",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
